{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Complete Project Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "from transformers import AutoProcessor, Blip2ForConditionalGeneration, BitsAndBytesConfig\n",
    "import torch\n",
    "import os\n",
    "import re\n",
    "from pycocoevalcap.bleu.bleu import Bleu\n",
    "from pycocoevalcap.rouge.rouge import Rouge\n",
    "from pycocoevalcap.cider.cider import Cider\n",
    "from pycocoevalcap.meteor.meteor import Meteor\n",
    "from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer\n",
    "import matplotlib.pyplot as plt\n",
    "from sentence_transformers import SentenceTransformer, util\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import io\n",
    "from peft import LoraConfig, get_peft_model\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from PIL import Image\n",
    "import pandas as pd\n",
    "import datasets\n",
    "from sklearn.model_selection import train_test_split\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device that input tensors are moved to; the model itself is placed by device_map=\"auto\".\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# The processor is loaded from the Salesforce/blip2-opt-2.7b checkpoint.\n",
    "processor = AutoProcessor.from_pretrained(\"Salesforce/blip2-opt-2.7b\")\n",
    "\n",
    "# Passing `load_in_8bit=True` directly to `from_pretrained` is deprecated; request the\n",
    "# 8-bit weights through the BitsAndBytesConfig that the import cell already brings in.\n",
    "model = Blip2ForConditionalGeneration.from_pretrained(\n",
    "    \"ybelkada/blip2-opt-2.7b-fp16-sharded\",\n",
    "    device_map=\"auto\",\n",
    "    quantization_config=BitsAndBytesConfig(load_in_8bit=True),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which prompt dataset to use: 'sd' or 'coco'.\n",
    "dataset_name = 'sd'\n",
    "# Fraction of rows passed to train_test_split as test_size.\n",
    "split_size = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: For dataset download instructions, please refer to the README.md file\n",
    "\n",
    "# The DataFrame gets its own name: binding it to `datasets` would shadow the imported\n",
    "# `datasets` module, which later cells still need for `datasets.Dataset.from_pandas`.\n",
    "if dataset_name == 'sd':\n",
    "    prompts_df = pd.read_parquet('sd_ext_prompts.parquet')\n",
    "elif dataset_name == 'coco':\n",
    "    prompts_df = pd.read_parquet('cocoprompts.parquet')\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError on the next statement.\n",
    "    raise ValueError(f\"Unknown dataset_name {dataset_name!r}; expected 'sd' or 'coco'\")\n",
    "\n",
    "# Fixed random_state keeps the split reproducible across runs.\n",
    "train_split, test_split = train_test_split(prompts_df, test_size=split_size, random_state=42)\n",
    "print(f\"Train dataset size: {len(train_split)}\")\n",
    "print(f\"Validation dataset size: {len(test_split)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap both pandas splits as `datasets.Dataset` objects.\n",
    "train_dataset, test_dataset = (\n",
    "    datasets.Dataset.from_pandas(split_df) for split_df in (train_split, test_split)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# View an example: display the first record of the converted training set\n",
    "train_dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "from PIL import Image\n",
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "class ImageCaptioningDataset(Dataset):\n",
    "    \"\"\"Torch dataset that pairs processed image tensors with the raw caption text.\n",
    "\n",
    "    Args:\n",
    "        dataset: Indexable collection of records. Each record must provide\n",
    "            'image_path' plus the caption column for the notebook-global\n",
    "            `dataset_name` ('Prompt' for 'sd', 'caption' for 'coco').\n",
    "        processor: Callable invoked as\n",
    "            `processor(images=..., padding=\"max_length\", return_tensors=\"pt\")`.\n",
    "    \"\"\"\n",
    "\n",
    "    # Caption column read for each supported value of the global `dataset_name`.\n",
    "    _TEXT_COLUMNS = {'sd': 'Prompt', 'coco': 'caption'}\n",
    "\n",
    "    def __init__(self, dataset, processor):\n",
    "        self.dataset = dataset\n",
    "        self.processor = processor\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the processor outputs for record `idx`, with its caption under \"text\".\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If the global `dataset_name` is neither 'sd' nor 'coco'.\n",
    "        \"\"\"\n",
    "        if dataset_name not in self._TEXT_COLUMNS:\n",
    "            # Without this check the sample would silently lack a \"text\" entry.\n",
    "            raise ValueError(f\"Unknown dataset_name {dataset_name!r}; expected 'sd' or 'coco'\")\n",
    "\n",
    "        item = self.dataset[idx]\n",
    "        # Close the file handle promptly and force 3-channel input, matching what\n",
    "        # generate_cap does; non-RGB files would otherwise reach the processor as-is.\n",
    "        with Image.open(item['image_path']) as raw_image:\n",
    "            image = raw_image.convert('RGB')\n",
    "        encoding = self.processor(images=image, padding=\"max_length\", return_tensors=\"pt\")\n",
    "        # Remove only the leading batch dimension; a bare squeeze() would also\n",
    "        # collapse any other size-1 axis.\n",
    "        encoding = {k: v.squeeze(0) for k, v in encoding.items()}\n",
    "        encoding[\"text\"] = item[self._TEXT_COLUMNS[dataset_name]]\n",
    "        return encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(batch):\n",
    "    \"\"\"Merge per-sample dicts into one batch dict.\n",
    "\n",
    "    Every key except \"text\" is stacked with torch.stack. The \"text\" entries are\n",
    "    tokenized together with padding by the notebook-global `processor.tokenizer`\n",
    "    and stored as \"input_ids\" and \"attention_mask\".\n",
    "    \"\"\"\n",
    "    collated = {}\n",
    "    for field in batch[0]:\n",
    "        if field == \"text\":\n",
    "            captions = [sample[\"text\"] for sample in batch]\n",
    "            tokens = processor.tokenizer(captions, padding=True, return_tensors=\"pt\")\n",
    "            collated[\"input_ids\"] = tokens[\"input_ids\"]\n",
    "            collated[\"attention_mask\"] = tokens[\"attention_mask\"]\n",
    "            continue\n",
    "        collated[field] = torch.stack([sample[field] for sample in batch])\n",
    "    return collated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LoRA settings: r=64, alpha=64, dropout=0.05, applied to the q/k/v/o projection modules.\n",
    "config = LoraConfig(\n",
    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
    "    r=64,\n",
    "    lora_alpha=64,\n",
    "    lora_dropout=0.05,\n",
    "    bias=\"none\",\n",
    ")\n",
    "\n",
    "# Wrap the loaded model with the adapters and print the trainable-parameter summary.\n",
    "model = get_peft_model(model, config)\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both loaders share the same batch size, shuffling and collate function.\n",
    "loader_kwargs = dict(batch_size=16, shuffle=True, collate_fn=collate_fn)\n",
    "\n",
    "train_ft_dataset = ImageCaptioningDataset(train_dataset, processor)\n",
    "train_dataloader = DataLoader(train_ft_dataset, **loader_kwargs)\n",
    "\n",
    "# Per-epoch evaluation only uses a slice of the test set sized at 10% of its rows.\n",
    "subset_size = int(0.1 * len(test_dataset))\n",
    "test_ft_dataset = ImageCaptioningDataset(test_dataset.take(subset_size), processor)\n",
    "test_dataloader = DataLoader(test_ft_dataset, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start Training and Save Best Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_captions_batch(batch, model, processor):\n",
    "    \"\"\"Generate captions for batch['pixel_values'] and return them as decoded strings.\n",
    "\n",
    "    Generation is capped at max_length=77 and special tokens are dropped on decode.\n",
    "    \"\"\"\n",
    "    token_ids = model.generate(batch['pixel_values'], max_length=77)\n",
    "    return processor.batch_decode(token_ids, skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer, util\n",
    "import numpy as np\n",
    "import datetime\n",
    "\n",
    "# Sentence-embedding model that train_and_evaluate passes to cal_cossim as the scorer\n",
    "judge_model = SentenceTransformer('all-MiniLM-L6-v2')\n",
    "\n",
    "\n",
    "def cal_cossim(reference:list,generated:list,judge_model)->dict:\n",
    "    \"\"\"Score each generated text against its reference by embedding cosine similarity.\n",
    "\n",
    "    Args:\n",
    "        reference: Reference texts.\n",
    "        generated: Generated texts; generated[i] is compared with reference[i].\n",
    "        judge_model: Embedding model exposing `encode(texts, convert_to_tensor=True)`.\n",
    "\n",
    "    Returns:\n",
    "        dict with \"mean_score\" (mean of the per-pair similarities, NaN when there\n",
    "        are no pairs) and \"scores\" (list of per-pair similarities as floats).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the two lists differ in length.\n",
    "    \"\"\"\n",
    "    if len(reference) != len(generated):\n",
    "        # Previously a shorter `generated` raised IndexError and a longer one was\n",
    "        # silently truncated; make the mismatch explicit.\n",
    "        raise ValueError(\n",
    "            f\"reference and generated must have the same length, got {len(reference)} and {len(generated)}\"\n",
    "        )\n",
    "    if len(reference) == 0:\n",
    "        # Nothing to score: skip encoding and np.mean([])'s RuntimeWarning.\n",
    "        return {\"mean_score\": float(\"nan\"), \"scores\": []}\n",
    "\n",
    "    # Encode each side in one batched call instead of one encode() call per sentence.\n",
    "    reference_emb = judge_model.encode(list(reference), convert_to_tensor=True)\n",
    "    generated_emb = judge_model.encode(list(generated), convert_to_tensor=True)\n",
    "\n",
    "    # Row-wise cosine similarity: entry i compares reference[i] with generated[i].\n",
    "    scores = torch.nn.functional.cosine_similarity(reference_emb, generated_emb, dim=1).tolist()\n",
    "    mean_score = np.mean(scores)\n",
    "\n",
    "    return {\"mean_score\":mean_score,\"scores\":scores}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_evaluate(model,processor, train_dataloader, test_dataloader, optimizer, epochs=10):\n",
    "    \"\"\"Train `model`, score it after every epoch and checkpoint the best-scoring epoch.\n",
    "\n",
    "    Relies on notebook globals: `device`, `judge_model`, and (only to name the\n",
    "    checkpoint directory) `dataset_name`, `train_dataset` and `split_size`.\n",
    "\n",
    "    Args:\n",
    "        model: Model called with input_ids / pixel_values / labels / attention_mask.\n",
    "        processor: Used to decode token ids back into text.\n",
    "        train_dataloader: Yields dicts with \"input_ids\", \"pixel_values\" and \"attention_mask\".\n",
    "        test_dataloader: Yields dicts with \"pixel_values\" and \"input_ids\" for scoring.\n",
    "        optimizer: Stepped once per training batch.\n",
    "        epochs: Number of passes over `train_dataloader`.\n",
    "\n",
    "    Returns:\n",
    "        dict with \"bst_model_path\" (directory of the best checkpoint, or None when no\n",
    "        checkpoint was written), \"train_avg_loss/e\" (mean training loss per epoch) and\n",
    "        \"val_avg_cos/e\" (mean validation cosine similarity per epoch).\n",
    "    \"\"\"\n",
    "    # Start below any cosine similarity so the first scored epoch is always checkpointed.\n",
    "    best_val_cos = float('-inf')\n",
    "    # Pre-bound so the final return cannot raise UnboundLocalError when nothing was\n",
    "    # saved (epochs == 0, or every score was NaN).\n",
    "    model_path = None\n",
    "    train_loss = []\n",
    "    val_cos = []\n",
    "    # One timestamp per call, so every \"best\" checkpoint of this run shares a directory.\n",
    "    now = datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\")\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        model.train()\n",
    "        total_loss = 0\n",
    "\n",
    "        print(f\"Epoch {epoch+1}\")\n",
    "\n",
    "        for batch in tqdm(train_dataloader,total=len(train_dataloader),desc=\"Train Progress\"):\n",
    "            input_ids = batch.pop(\"input_ids\").to(device)\n",
    "            pixel_values = batch.pop(\"pixel_values\").to(device, torch.float16)\n",
    "            attention_mask = batch.pop(\"attention_mask\").to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            # The caption token ids are used both as text input and as labels.\n",
    "            outputs = model(input_ids=input_ids,\n",
    "                            pixel_values=pixel_values,\n",
    "                            labels=input_ids,\n",
    "                            attention_mask=attention_mask)\n",
    "\n",
    "            loss = outputs.loss\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            total_loss += loss.item()\n",
    "\n",
    "        avg_train_loss = total_loss / len(train_dataloader)\n",
    "        train_loss.append(avg_train_loss)\n",
    "        print(f\"Epoch {epoch+1} Train Loss: {avg_train_loss}\")\n",
    "\n",
    "        model.eval()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            total_genval_captions = []\n",
    "            total_ref_captions = []\n",
    "            for batch in tqdm(test_dataloader,total=len(test_dataloader),desc=\"Evaluation Progress\"):\n",
    "                captions = generate_captions_batch(batch, model, processor)\n",
    "                # Reference text is recovered by decoding the tokenized captions.\n",
    "                refs = processor.batch_decode(batch['input_ids'], skip_special_tokens=True)\n",
    "                total_genval_captions.extend(captions)\n",
    "                total_ref_captions.extend(refs)\n",
    "\n",
    "            cosscore_val = cal_cossim(reference=total_ref_captions,generated=total_genval_captions,judge_model=judge_model)\n",
    "\n",
    "        val_cos.append(cosscore_val['mean_score'])\n",
    "        print(f\"Validation Cosine-similarity: {cosscore_val['mean_score']}\")\n",
    "\n",
    "        if cosscore_val['mean_score'] >= best_val_cos:\n",
    "            best_val_cos = cosscore_val['mean_score']\n",
    "            print(f\"Saving new best model with val_cos: {best_val_cos}\")\n",
    "            model_path = f\"blip2_ft_{dataset_name}/{now}/blip2_{len(train_dataset)}_{split_size}_{epochs}/best\"\n",
    "            model.save_pretrained(model_path)\n",
    "            #processor.save_pretrained(model_path)\n",
    "            # The context manager flushes and closes the file; it was previously left open.\n",
    "            with open(model_path+'/cosine_similarity.txt', 'w') as similarity_file:\n",
    "                similarity_file.write('Epoch, Cosine Similarity\\n')  # write the header row\n",
    "                similarity_file.write(f'{epoch+1}, {best_val_cos}\\n')\n",
    "\n",
    "    return {\"bst_model_path\":model_path,\"train_avg_loss/e\":train_loss,\"val_avg_cos/e\":val_cos}\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train with Adam (lr=5e-5) for `num_epoches` epochs and keep the returned summary dict.\n",
    "num_epoches = 10\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)\n",
    "best_model_output = train_and_evaluate(\n",
    "    model=model,\n",
    "    processor=processor,\n",
    "    train_dataloader=train_dataloader,\n",
    "    test_dataloader=test_dataloader,\n",
    "    optimizer=optimizer,\n",
    "    epochs=num_epoches,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: reload the best checkpoint before saving\n",
    "# model = Blip2ForConditionalGeneration.from_pretrained(best_model_output['bst_model_path'], device_map=\"auto\", load_in_8bit=True)\n",
    "# processor = AutoProcessor.from_pretrained(best_model_output['bst_model_path'])\n",
    "\n",
    "# The model and the processor are written to the same directory.\n",
    "final_model_dir = f'blip2_di_{dataset_name}/final_model/blip2_{len(train_dataset)}_{split_size}_{num_epoches}'\n",
    "model.save_pretrained(final_model_dir, safe_serialization=False)\n",
    "processor.save_pretrained(final_model_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate initial results with newly trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "import io\n",
    "\n",
    "def generate_cap(df,u_model,u_processor)->list:\n",
    "    \"\"\"Generate one caption per row of `df` from the image at row['image_path'].\n",
    "\n",
    "    Relies on the notebook-global `device` for placing the processed inputs.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with an 'image_path' column.\n",
    "        u_model: Model exposing eval() and generate(pixel_values=...).\n",
    "        u_processor: Processor used both to preprocess images and to decode token ids.\n",
    "\n",
    "    Returns:\n",
    "        List of stripped caption strings, in row order.\n",
    "    \"\"\"\n",
    "    u_model.eval()\n",
    "    generation = []\n",
    "    for _, row in tqdm(df.iterrows(), total=len(df)):\n",
    "        # Close the file handle as soon as the RGB copy has been made.\n",
    "        with Image.open(row['image_path']) as raw_image:\n",
    "            image = raw_image.convert('RGB')\n",
    "        inputs = u_processor(images=image, return_tensors=\"pt\").to(device)\n",
    "\n",
    "        generated_ids = u_model.generate(pixel_values=inputs.pixel_values)\n",
    "        # Decode with the processor that was passed in, not the notebook-global `processor`.\n",
    "        generated_text = u_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()\n",
    "        generation.append(generated_text)\n",
    "\n",
    "    return generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No earlier cell defines `test_df`, so a fresh Restart & Run All would stop here with\n",
    "# a NameError. Bind it to the held-out pandas split produced by train_test_split.\n",
    "test_df = test_split\n",
    "\n",
    "test_generated_caption = generate_cap(test_df,model,processor)\n",
    "if dataset_name == 'sd':\n",
    "    test_reference_prompt = test_df['Prompt'].tolist()\n",
    "elif dataset_name == 'coco':\n",
    "    test_reference_prompt = test_df['caption'].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_compare_df = pd.DataFrame({'image_path':test_df['image_path'],'generated_caption':test_generated_caption,'reference_prompt':test_reference_prompt})\n",
    "# to_parquet does not create missing parent directories, so make sure the target exists.\n",
    "# NOTE(review): this tree is 'blip2_DI_...' while the model cell saves under 'blip2_di_...' - confirm the casing is intended.\n",
    "result_dir = f'blip2_DI_{dataset_name}/result'\n",
    "os.makedirs(result_dir, exist_ok=True)\n",
    "test_compare_df.to_parquet(f'{result_dir}/blip2_{split_size}_{num_epoches}_test.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the generated captions next to their reference prompts\n",
    "test_compare_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ldm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
